feature(pdd): enable P/D disaggregation with NIXL host KV transfer #477
rebel-ykchoi wants to merge 4 commits into dev
Wire vLLM KV transfer to an RBLN-specific NIXL connector and host-side
buffers so prefill/decode can run on separate engines with host-to-host (H2H) KV transfer.
KV connector / registration
- add RblnNixlConnector (scheduler/worker) extending the upstream NixlConnector.
- register connector name "RblnNixlConnector" in the kv_connector factory.
Platform
- expose NIXL hints: get_nixl_supported_devices (rbln -> cpu) and
get_nixl_memory_type ("DRAM").
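A minimal sketch of what these platform hints could look like; the method names and return values come from the summary above, while the exact upstream signatures and return shapes are assumptions:

```python
class RblnPlatform:
    """Hypothetical sketch; the real code in vllm_rbln/platform.py may differ."""

    @classmethod
    def get_nixl_supported_devices(cls) -> dict:
        # Map the "rbln" device type to the host buffer kinds ("cpu") NIXL
        # may transfer from/to, since KV staging happens in host memory.
        return {"rbln": ("cpu",)}

    @classmethod
    def get_nixl_memory_type(cls) -> str:
        # Host-resident KV buffers register with NIXL as DRAM.
        return "DRAM"
```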
Scheduler (rbln_scheduler.py)
- allow a kv_consumer request to be scheduled together with other requests in
the decode stage
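A hypothetical sketch of such a scheduling guard; the attribute and function names here are assumptions for illustration, not the PR's actual code:

```python
def can_join_decode_batch(request, batch_has_local_prefill: bool) -> bool:
    # Hypothetical guard: a kv_consumer's KV arrives over NIXL rather than
    # from local prefill, so it may be batched with decode requests, but it
    # should not be mixed into a batch that also runs local prefill.
    is_kv_consumer = getattr(request, "kv_transfer_params", None) is not None
    if is_kv_consumer:
        return not batch_has_local_prefill
    return True
```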
Model runner (rbln_model_runner.py)
- override maybe_get_kv_connector_output(..., wait_for_save), deriving
wait_for_save from the last prefill chunk.
- replace the generic copy_kv_blocks with rbln_copy_kv_blocks, using the
runtime's _update_kv_cache / _fetch_kv_cache.
- bind_kv_cache_name + per-layer names for mark_static_address when compiling.
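A rough sketch of what per-layer KV naming could look like; the helper name comes from the file summary, but the naming scheme is an assumption:

```python
def bind_kv_cache_name(kv_caches: list) -> dict:
    # Hypothetical naming scheme: give each layer's KV buffer a stable
    # per-layer name so its address can be pinned ("marked static") when
    # compiling, e.g. via torch._dynamo.mark_static_address on real tensors.
    return {f"kv_cache.layer_{i}": cache for i, cache in enumerate(kv_caches)}
```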
Attention backend (flash_attention.py)
- Report backend name as FLASH_ATTN for upstream compatibility.
Examples
- add experimental examples/experimental/pd_disaggregation/toy_proxy_server.py
(FastAPI proxy routing chat completions to prefill vs decode HTTP backends).
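The routing core of such a toy proxy can be sketched as round-robin backend selection; the backend URLs and the pick_backend helper below are assumptions, not the actual toy_proxy_server.py wiring:

```python
from itertools import cycle

# Hypothetical backend pools for a prefill/decode-disaggregated deployment.
PREFILL_BACKENDS = ["http://localhost:8100/v1", "http://localhost:8101/v1"]
DECODE_BACKENDS = ["http://localhost:8200/v1"]

_prefill_iter = cycle(PREFILL_BACKENDS)
_decode_iter = cycle(DECODE_BACKENDS)

def pick_backend(stage: str) -> str:
    # Round-robin across prefill instances; decode requests go to decoders,
    # which pull the KV cache from the prefiller over NIXL before streaming.
    if stage == "prefill":
        return next(_prefill_iter)
    return next(_decode_iter)
```

A proxy built on this would first send the request (with max_tokens forced to 1) to a prefill backend, then forward the original request to a decode backend and stream its response to the client.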
Pull request overview
This PR wires vLLM KV transfer to an RBLN-specific NIXL connector and host-side buffers to enable prefill/decode disaggregation (with H2H KV transfer), plus related scheduler/model-runner/attention-backend integration and an E2E accuracy test harness.
Changes:
- Add RblnNixlConnector and register it via a connector factory import hook.
- Update worker/model-runner/scheduler and FlashAttention metadata plumbing to support P/D disaggregation and the KV connector lifecycle.
- Add an end-to-end NIXL integration test setup (proxy server + lm-eval accuracy test script).
Reviewed changes
Copilot reviewed 12 out of 18 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| vllm_rbln/v1/worker/utils.py | Add bind_kv_cache_name helper to name KV cache buffers for compilation/static address marking. |
| vllm_rbln/v1/worker/rbln_worker.py | Move KV-transfer init to initialize_from_config, add handshake metadata helper, add shutdown hook. |
| vllm_rbln/v1/worker/rbln_model_runner.py | Integrate KV connector output handling, preemption handling, host-buffer KV copy ops, static KV cache naming. |
| vllm_rbln/v1/core/rbln_scheduler.py | Adjust scheduling flow to allow remote-KV consumers to be batched with decode requests without mixing local prefill. |
| vllm_rbln/v1/attention/backends/flash_attention.py | Switch prefill/decode detection to explicit is_prefill param; report backend name as FLASH_ATTN. |
| vllm_rbln/platform.py | Expose NIXL device/memory hints for RBLN platform. |
| vllm_rbln/distributed/kv_transfer/kv_connector/v1/rbln_nixl_connector.py | Implement RBLN-specific NIXL connector scheduler/worker behavior and host transfer buffers. |
| vllm_rbln/distributed/kv_transfer/kv_connector/factory.py | Register connector name RblnNixlConnector in the upstream factory. |
| vllm_rbln/init.py | Ensure connector factory registration runs when ops are registered. |
| tests/torch_compile/e2e/v1/kv_connector/nixl_integration/toy_proxy_server.py | Add a proxy server to route prefill to prefiller instances and streaming decode to decoder instances. |
| tests/torch_compile/e2e/v1/kv_connector/nixl_integration/test_accuracy.py | Add an lm-eval based accuracy test for the disaggregated setup. |
| tests/torch_compile/e2e/v1/kv_connector/nixl_integration/run_accuracy_test.sh | Add a runner script that launches prefill/decode instances + proxy and executes the accuracy test. |
| (various __init__.py files) | Package scaffolding for the new connector modules and tests. |
```python
def sleep(self, level: int = 1) -> None:
    logger.warning("sleep mode is not supported on RBLN, ignore it.")
    pass

def wake_up(self, tags: list[str] | None = None) -> None:
    logger.warning("sleep mode is not supported on RBLN, ignore it.")
    pass
```
initialize_cache() was removed, but it’s still used by existing unit tests (e.g. tests/torch_compile/unit/v1/worker/test_rbln_worker.py::TestInitializeCache). This will break the worker interface expected by tests (and likely callers). Please restore initialize_cache() (or update callers/tests consistently) to keep cache_config.num_gpu_blocks/num_cpu_blocks configurable.
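A minimal sketch of the restored method the reviewer asks for; the cache_config attribute names follow the vLLM worker convention, and the surrounding class is simplified for illustration:

```python
from types import SimpleNamespace

class RblnWorkerSketch:
    """Illustrative stand-in for the real RBLN worker class."""

    def __init__(self):
        self.cache_config = SimpleNamespace(num_gpu_blocks=None,
                                            num_cpu_blocks=None)

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # Keep num_gpu_blocks/num_cpu_blocks configurable, as the existing
        # unit tests (TestInitializeCache) expect.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
```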
```python
if direction == "h2d":
    kv_caches = src_kv_caches
    copy_fn = runtime._update_kv_cache
else:
    kv_caches = dst_kv_caches
    copy_fn = runtime._fetch_kv_cache

for idx in src_block_ids:
    for kv_name, kv_cache in kv_caches.items():
        block_size = kv_cache.shape[-2]
        copy_fn(kv_cache.data_ptr(), idx, 0, block_size, kv_name)
```
rbln_copy_kv_blocks ignores dst_block_ids and always uses idx from src_block_ids when calling the runtime copy function. This breaks the expected src↔dst mapping when blocks need to be copied to different indices (e.g., compaction/relocation), and it also means dst_kv_caches is never used for the h2d case. Consider iterating with for src_id, dst_id in zip(src_block_ids, dst_block_ids) and passing the correct destination block id (and source/destination buffer) to the runtime API.
Suggested change:

```diff
-if direction == "h2d":
-    kv_caches = src_kv_caches
-    copy_fn = runtime._update_kv_cache
-else:
-    kv_caches = dst_kv_caches
-    copy_fn = runtime._fetch_kv_cache
-for idx in src_block_ids:
-    for kv_name, kv_cache in kv_caches.items():
-        block_size = kv_cache.shape[-2]
-        copy_fn(kv_cache.data_ptr(), idx, 0, block_size, kv_name)
+for src_id, dst_id in zip(src_block_ids, dst_block_ids):
+    if direction == "h2d":
+        kv_caches = src_kv_caches
+        block_id = dst_id
+        copy_fn = runtime._update_kv_cache
+    else:
+        kv_caches = dst_kv_caches
+        block_id = src_id
+        copy_fn = runtime._fetch_kv_cache
+    for kv_name, kv_cache in kv_caches.items():
+        block_size = kv_cache.shape[-2]
+        copy_fn(kv_cache.data_ptr(), block_id, 0, block_size, kv_name)
```
```python
response = await client_info["client"].post(
    endpoint, json=req_data, headers=headers
)
```
httpx.AsyncClient is configured with base_url=.../v1, but requests are made with endpoints like /completions and /chat/completions (leading slash). In httpx, a leading slash treats the request URL as absolute-path and will drop the /v1 prefix from base_url, so this proxy may call /completions on the backend instead of /v1/completions. Use relative paths (e.g., completions) or remove /v1 from base_url and keep /v1/... in the endpoints consistently (also applies to stream_service_response).
Suggested change:

```diff
-response = await client_info["client"].post(
-    endpoint, json=req_data, headers=headers
+normalized_endpoint = endpoint.lstrip("/")
+response = await client_info["client"].post(
+    normalized_endpoint, json=req_data, headers=headers
 )
```